{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pairwise Image Registration\n",
    "\n",
    "This tutorial is an introduction to using the deepali library for the spatial alignment of two images."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's begin with the common imports used throughout this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import Any, Callable, Optional, Tuple, Type, cast\n",
    "\n",
    "from IPython.utils import io\n",
    "from matplotlib.figure import Figure\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import torch\n",
    "from torch import Tensor, optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from deepali.core.environ import cuda_visible_devices\n",
    "except ImportError:\n",
    "    if not os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "        raise\n",
    "    !git clone https://github.com/BioMedIA/deepali.git && pip install ./deepali\n",
    "    from deepali.core.environ import cuda_visible_devices"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
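    "Choose the CUDA device to use, if available. The code below selects the first device listed in `CUDA_VISIBLE_DEVICES` when CUDA is available; set this environment variable to an empty string to run this tutorial on the CPU instead."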
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "# Use first device specified in CUDA_VISIBLE_DEVICES if CUDA is available\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() and cuda_visible_devices() else \"cpu\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "\n",
    "The images used in this tutorial are from the public [MNIST dataset](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html) available through `torchvision`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "with io.capture_output() as captured:  # type: ignore[reportPrivateImportUsage]\n",
    "    mnist = MNIST(root=\"data\", download=True, transform=ToTensor())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define a utility function named `imshow` for displaying the example images used in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(\n",
    "    image: Tensor,\n",
    "    label: Optional[str] = None,\n",
    "    ax: Optional[plt.Axes] = None,\n",
    "    **kwargs,\n",
    ") -> None:\n",
    "    r\"\"\"Render image data in last two tensor dimensions using matplotlib.pyplot.imshow().\n",
    "\n",
    "    Args:\n",
    "        image: Image tensor of shape ``(..., H, W)``.\n",
    "        label: Image label to display in the axes title.\n",
    "        ax: Figure axes to render the image in. If ``None``, a new figure is created.\n",
    "        kwargs: Keyword arguments to pass on to ``matplotlib.pyplot.imshow()``.\n",
    "            When ``ax`` is ``None``, can contain ``figsize`` to specify the size of\n",
    "            the figure created for displaying the image.\n",
    "\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        figsize = kwargs.pop(\"figsize\", (4, 4))\n",
    "        _, ax = plt.subplots(figsize=figsize)\n",
    "    kwargs[\"cmap\"] = kwargs.get(\"cmap\", \"gray\")\n",
    "    ax.imshow(image.reshape((-1,) + image.shape[-2:])[0].cpu().numpy(), **kwargs)\n",
    "    if label:\n",
    "        ax.set_title(label, fontsize=16, y=1.04)\n",
    "    ax.get_xaxis().set_visible(False)\n",
    "    ax.get_yaxis().set_visible(False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now display a random selection of MNIST images, from which we choose an example pair of images to register."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "digits = tuple(range(10))\n",
    "\n",
    "n_rows = len(digits)\n",
    "n_cols = 10\n",
    "\n",
    "rng = torch.Generator().manual_seed(0)\n",
    "perm = torch.randperm(len(mnist), generator=rng)\n",
    "\n",
    "samples: dict[int, list[Tensor]] = {digit: [] for digit in digits}\n",
    "for i in perm:\n",
    "    image, digit = mnist[i]\n",
    "    samples[digit].append(image)\n",
    "    if all(len(samples[digit]) >= n_cols for digit in digits):\n",
    "        break\n",
    "samples = {digit: images[:n_cols] for digit, images in samples.items()}\n",
    "\n",
    "_, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), tight_layout=True)\n",
    "for row, digit in zip(axes, digits):\n",
    "    for ax, image in zip(row, samples[digit]):\n",
    "        imshow(image, ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = samples[9][0].float()\n",
    "source = samples[9][2].float()\n",
    "\n",
    "_, axes = plt.subplots(1, 2, figsize=(8, 4), tight_layout=True)\n",
    "\n",
    "imshow(target, \"target\", ax=axes[0])\n",
    "imshow(source, \"source\", ax=axes[1])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now try to align the `source` image of our chosen digit with the `target` image.\n",
    "\n",
    "First, we define a spatial transformation model, which determines the type of spatial transformation that we want to apply. As the two chosen examples differ mainly by a rotation, we select an `EulerRotation` from the `deepali.spatial` library. In 2D, this transformation has a single parameter, namely the rotation angle. Each spatial transformation is defined with respect to normalized coordinates as used by [torch.nn.functional.grid_sample()](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). Every spatial transform with base type `SpatialTransform` maps points given with respect to these coordinates to normalized coordinates defined with respect to the same domain, i.e., the domain and codomain of the normalized coordinate map are the same. As these normalized coordinates range from -1 to 1 along each image side, the origin of this coordinate system is at the image center. The rotation is therefore about the image center.\n",
    "\n",
    "For images sampled on a regular grid, we use the `Grid` class defined by `deepali.core` to convert between different coordinate systems. An instance of `Grid` is also used by the spatial transform to define its domain and codomain. Commonly, this is the sampling grid of the fixed target image, though it need not be. It could also be a subimage region or, if both images are mapped symmetrically to a common space, the grid of this reference space. For medical images, the `Grid` defines the mapping between image element indices, world coordinates, and normalized coordinates. The position of the image within the world is defined by `Grid.center()` (or `Grid.origin()`), its orientation by the `Grid.direction()` cosine matrix, and the extent of each image element in world units (e.g., millimeters) by `Grid.spacing()`.\n",
    "\n",
    "In the case of MNIST, we simply use a world coordinate system that is identical to the image space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepali.core import Grid\n",
    "\n",
    "grid = Grid(shape=target.shape[1:])\n",
    "\n",
    "print(\"size:     \", list(grid.size()))\n",
    "print(\"origin:   \", grid.origin().tolist())\n",
    "print(\"center:   \", grid.center().tolist())\n",
    "print(\"spacing:  \", grid.spacing().tolist())\n",
    "print(\"direction:\", grid.direction().tolist())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<details>\n",
    "<summary>Images in world space</summary>\n",
    "\n",
    "Traditionally, the objective of image registration is to align images within the world coordinate system. This is especially important when aligning images from different sources, or images with point clouds (e.g., from depth sensors). In deep learning applications, the alignment is often carried out with respect to the normalized coordinates, and the image-to-world mapping is ignored. However, when exporting a spatial transform for visualization or for use with classic registration tools such as [ANTs](https://antspy.readthedocs.io/en/latest/registration.html#), [ITK](https://itk.org/Doxygen413/html/RegistrationPage.html), [MIRTK](http://mirtk.github.io/), [NiftyReg](http://cmictig.cs.ucl.ac.uk/wiki/index.php/NiftyReg) and others, we need to recover the correct world information and also convert the spatial transform to a world coordinate map. For this purpose, the `Grid` and, more specifically, the `FlowField` tensor type of the `deepali.data` library will come in handy, though we will cover this in another tutorial.\n",
    "\n",
    "</details>\n",
    "</br>\n",
    "<details>\n",
    "<summary>Grid center vs. origin</summary>\n",
    "\n",
    "Notice that `grid` stores the position of the image center as an attribute, from which the image origin, i.e., the world coordinates of the sampling point with all-zero indices, is computed using the grid spacing and direction. This simplifies grid resizing operations, such as increasing (e.g., `Grid.upsample()`) or decreasing (e.g., `Grid.downsample()`) the number of image sampling points along each spatial grid dimension. We will go into more detail on the different coordinate systems and the relations between them in a separate tutorial on medical image volumes.\n",
    "\n",
    "</details>\n",
    "</br>\n",
    "<details>\n",
    "<summary>Grid size vs. shape</summary>\n",
    "\n",
    "The `Grid` properties are defined with respect to world coordinate axes in the order `x`, `y`, etc. The `Grid.shape` property, however, is in reverse order to match the ordering of the spatial dimensions in the corresponding image data tensor, which has shape `(C, ..., Y, X)`, where `C` is the number of image channels, and `..., Y, X` are the sizes of the image, and thus of the sampling grid, along the respective spatial dimensions. The `Grid.size()` method and the `size` argument of the `Grid` init function specify the size of each spatial dimension in the original order, i.e., `x`, `y`, etc. In contrast, `torch.Tensor.size()` uses the ordering of `Grid.shape`. To avoid confusing the different orderings of spatial dimensions, it is advisable to prefer `torch.Tensor.shape` over `torch.Tensor.size()`, which is also consistent with the use of `numpy.ndarray.shape` and `Grid.shape`.\n",
    "\n",
    "</details>"
   ]
  },
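  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the two orderings, consider a plain PyTorch tensor with a hypothetical, non-square spatial size, so that the two spatial dimensions can be told apart. The variable names are illustrative only. Following the explanation above, `spatial_shape` corresponds to the ordering used by `Grid.shape`, and `spatial_size` to the ordering used by `Grid.size()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical single-channel image with 28 rows (y) and 32 columns (x)\n",
    "example_image = torch.zeros(1, 28, 32)  # (C, Y, X)\n",
    "\n",
    "# Spatial shape in tensor order, as used by Tensor.shape (and Grid.shape)\n",
    "spatial_shape = tuple(example_image.shape[1:])  # (Y, X)\n",
    "\n",
    "# Spatial size in world axis order, as used by Grid.size()\n",
    "spatial_size = tuple(reversed(spatial_shape))  # (x, y)\n",
    "\n",
    "print('shape (Y, X):', spatial_shape)\n",
    "print('size  (x, y):', spatial_size)"
   ]
  },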
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grid search\n",
    "\n",
    "To begin, let's try a brute-force [grid search](https://en.wikipedia.org/wiki/Hyperparameter_optimization#Grid_search) for the angle that minimizes a chosen objective. This objective measures the similarity of the fixed target and moving source images as a proxy for the quality of their spatial alignment. For the grid search, we manually set the rotation angle to different values, evaluate our chosen similarity measure, and keep track of the angle that has attained the minimum value so far.\n",
    "\n",
    "<details>\n",
    "<summary>Similarity or loss?</summary>\n",
    "\n",
    "Though in image registration we often refer to the data fidelity term of the objective function as a similarity measure, by convention its value reflects a dissimilarity, or spatial alignment error, which we want to minimize. A similarity measure whose value increases with better spatial alignment of the images can be negated so that it can be minimized instead. Furthermore, to align with the terminology used in deep learning, all objective function terms are called losses and are defined in the `deepali.losses` library.\n",
    "\n",
    "</details>\n",
    "\n",
    "Here, we select the mean squared error (MSE) as our similarity measure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepali.losses import functional as L\n",
    "\n",
    "sim_loss = L.mse_loss"
   ]
  },
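  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the sign convention described above, here is a minimal plain-PyTorch sketch of a similarity measure whose value increases with better alignment, namely a global normalized cross-correlation (NCC), turned into a loss by negation. The functions `ncc()` and `ncc_loss()` are hypothetical helpers written for this illustration; they are not the implementations provided by `deepali.losses`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "def ncc(a: Tensor, b: Tensor, eps: float = 1e-8) -> Tensor:\n",
    "    # Global normalized cross-correlation per batch element, with values in [-1, 1]\n",
    "    a = a.flatten(1)\n",
    "    b = b.flatten(1)\n",
    "    a = a - a.mean(dim=1, keepdim=True)\n",
    "    b = b - b.mean(dim=1, keepdim=True)\n",
    "    return (a * b).sum(dim=1) / (a.norm(dim=1) * b.norm(dim=1) + eps)\n",
    "\n",
    "\n",
    "def ncc_loss(warped: Tensor, target: Tensor) -> Tensor:\n",
    "    # Negate the similarity so that better alignment yields a lower loss value\n",
    "    return -ncc(warped, target).mean()"
   ]
  },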
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Spatial transformation modules are defined by the `deepali.spatial` library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import deepali.spatial as spatial"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given a normalized coordinate map in the form of a `SpatialTransform`, we need an operation that applies this spatial transformation to the moving source image. `deepali` provides several ways to do this, all built on `torch.nn.functional.grid_sample()`: mainly the `grid_sample()` function of the functional `deepali.core` API, the image sampling modules defined in `deepali.modules`, and the spatial transformer modules defined in `deepali.spatial`. The latter can be combined directly with a spatial transform without having to work with the normalized coordinates explicitly. Unlike a `SpatialTransform`, a `SpatialTransformer` takes as input a tensor representing the data to which the spatial transformation should be applied. More specifically, the `ImageTransformer` takes as input an image batch tensor of shape `(N, C, ..., Y, X)` and samples it at the spatially transformed grid points of the fixed target image domain."
   ]
  },
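  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the following minimal sketch uses only plain PyTorch to rotate an image batch with `torch.nn.functional.affine_grid()` and `torch.nn.functional.grid_sample()`. Since the affine matrix acts on normalized coordinates in [-1, 1], whose origin is the image center, the rotation is about the image center. The helper `rotate_about_center()` is hypothetical and assumes `align_corners=False`; it is not how `deepali` implements `EulerRotation` or `ImageTransformer`, and its sign convention for the angle may differ from that of `deepali`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "def rotate_about_center(images: Tensor, angle_deg: float) -> Tensor:\n",
    "    # images: batch of shape (N, C, H, W)\n",
    "    n = images.shape[0]\n",
    "    c = math.cos(math.radians(angle_deg))\n",
    "    s = math.sin(math.radians(angle_deg))\n",
    "    # 2x3 affine matrix acting on normalized coordinates, whose origin is the image center\n",
    "    theta = torch.tensor([[c, -s, 0.0], [s, c, 0.0]], dtype=images.dtype, device=images.device)\n",
    "    theta = theta.unsqueeze(0).repeat(n, 1, 1)\n",
    "    # Normalized coordinates of the points at which the input images are sampled\n",
    "    sample_points = F.affine_grid(theta, list(images.shape), align_corners=False)\n",
    "    return F.grid_sample(images, sample_points, align_corners=False)"
   ]
  },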
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_batch = target.unsqueeze(0)  # (N, C, Y, X)\n",
    "source_batch = source.unsqueeze(0)  # (N, C, Y, X)\n",
    "\n",
    "rotation = spatial.EulerRotation(grid, params=False)\n",
    "transformer = spatial.ImageTransformer(rotation)\n",
    "\n",
    "angle_space = torch.arange(-180.0, 180.0, 1.0)  # 1 degree steps; -180 and 180 would be the same rotation\n",
    "best_angle_deg = torch.tensor(0.0)\n",
    "min_loss_value = torch.tensor(torch.inf)\n",
    "\n",
    "bar_format = \"{l_bar}{bar}{rate_fmt}{postfix}\"\n",
    "for angle_deg in (pbar := tqdm(angle_space, bar_format=bar_format)):\n",
    "    rotation.angles_(angle_deg.deg2rad().reshape(1, 1))\n",
    "    warped_batch = transformer(source_batch)\n",
    "    loss = sim_loss(warped_batch, target_batch)\n",
    "    if loss.lt(min_loss_value):\n",
    "        best_angle_deg = angle_deg\n",
    "        min_loss_value = loss\n",
    "        pbar.set_postfix(dict(loss=loss.item(), angle=angle_deg.item()))\n",
    "\n",
    "rotation.angles_(best_angle_deg.deg2rad().reshape(1, 1))\n",
    "warped: Tensor = transformer(source_batch)[0]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n",
    "\n",
    "imshow(target, \"target\", ax=axes[0])\n",
    "imshow(warped, \"warped\", ax=axes[1])\n",
    "imshow(source, \"source\", ax=axes[2])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the grid search, we used `params=False` in the init function of the spatial transform to indicate that the parameters of the transform (the Euler angle in the case of the chosen 2D rotation) are not optimizable. With this setting, the `SpatialTransform.parameters()` method inherited from `torch.nn.Module` yields no parameters, and the `params` property of the spatial transform is of type `torch.Tensor`. With the default setting of `params=True`, the `params` property is of type `torch.nn.parameter.Parameter` instead. We will use this default in the next section."
   ]
  },
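  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The distinction between optimizable and non-optimizable parameters builds on standard `torch.nn.Module` behavior. The following minimal sketch shows one way such a switch can be implemented, by registering the angle either as an `nn.Parameter` or as a buffer. The `ToyRotation` class is hypothetical and only illustrates the mechanism; `deepali` may store non-optimizable parameters differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class ToyRotation(nn.Module):\n",
    "    # Hypothetical module holding a single 2D rotation angle (in radians)\n",
    "\n",
    "    def __init__(self, params: bool = True) -> None:\n",
    "        super().__init__()\n",
    "        angle = torch.zeros(1, 1)\n",
    "        if params:\n",
    "            # Optimizable: returned by parameters() and thus updated by an optimizer\n",
    "            self.params = nn.Parameter(angle)\n",
    "        else:\n",
    "            # Non-optimizable: stored as a buffer, which still moves along with .to(device)\n",
    "            self.register_buffer('params', angle)\n",
    "\n",
    "\n",
    "fixed = ToyRotation(params=False)\n",
    "learnable = ToyRotation(params=True)\n",
    "print(len(list(fixed.parameters())), type(fixed.params).__name__)\n",
    "print(len(list(learnable.parameters())), type(learnable.params).__name__)"
   ]
  },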
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradient descent\n",
    "\n",
    "While grid search is a viable option for this toy example with a single parameter, the number of objective function evaluations grows exponentially with the number of parameters of the spatial transform. Moreover, grid search requires a discretization of the parameter space. Instead, we can optimize the image alignment by [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent), using the same PyTorch optimizers that are normally used to fit a neural network to a training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_batch = target.unsqueeze(0)  # (N, C, Y, X)\n",
    "source_batch = source.unsqueeze(0)  # (N, C, Y, X)\n",
    "\n",
    "rotation = spatial.EulerRotation(grid)\n",
    "transformer = spatial.ImageTransformer(rotation)\n",
    "optimizer = optim.Adam(transformer.parameters(), lr=1e-2)\n",
    "\n",
    "iterations = 100\n",
    "\n",
    "bar_format = \"{l_bar}{bar}{rate_fmt}{postfix}\"\n",
    "for _ in (pbar := tqdm(range(iterations), bar_format=bar_format)):\n",
    "    warped_batch = transformer(source_batch)\n",
    "    loss = sim_loss(warped_batch, target_batch)\n",
    "    angle_deg = rotation.angles().rad2deg().item()\n",
    "    pbar.set_postfix(dict(loss=loss.item(), angle=angle_deg))\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "with torch.inference_mode():\n",
    "    warped: Tensor = transformer(source_batch)[0]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n",
    "\n",
    "imshow(target, \"target\", ax=axes[0])\n",
    "imshow(warped, \"warped\", ax=axes[1])\n",
    "imshow(source, \"source\", ax=axes[2])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unfortunately, for this given pair of images and the chosen optimizer and hyperparameters such as learning rate and momentum, the registration resulted in a suboptimal solution compared to our previous grid search result. What happened?\n",
    "\n",
    "Gradient descent is a local optimization technique and is sensitive to the initialization. Furthermore, the nearly binary images, with large homogeneous regions (mostly background), combined with the chosen similarity measure, provide an insufficient signal (limited capture range) as to which way the rotation should be adjusted at each gradient step to minimize the loss and thereby improve the alignment. Increasing or decreasing the angle by a small amount changes the objective value similarly, so it is difficult to judge which direction is better.\n",
    "\n",
    "## Multi-scale registration\n",
    "\n",
    "Traditional optimization-based image registration overcomes this with a multi-scale or multi-resolution optimization scheme. For this, we first downsample the images a number of times, find the optimal rotation between these lower resolution images, in which each image element has a larger spatial extent, and then continue at the next higher resolution, initialized with the solution found at the lower resolution. To generate a multi-scale representation of the images, Gaussian blurring is commonly applied. Blurring before subsampling also avoids aliasing, in accordance with the [Nyquist-Shannon sampling theorem](https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem). The resulting pyramid of differently sized images is referred to as a [Gaussian image pyramid](https://en.wikipedia.org/wiki/Pyramid_(image_processing)#Gaussian_pyramid). In deep learning based image registration, a multi-scale representation of the input images is usually computed by the neural network itself.\n",
    "\n",
    "<details>\n",
    "<summary>Sampling grid pyramid</summary>\n",
    "\n",
    "Notice that the `Grid.pyramid()` function can be used to construct the sampling grids for the different scales of the image pyramid. The corresponding `Image.pyramid()` function of the `Image` tensor type of the `deepali.data` library generates both the data tensors of the different pyramid levels and the respective sampling grids, which are attached to the images at each level.\n",
    "\n",
    "</details>\n",
    "</br>\n",
    "<details>\n",
    "<summary>Multi-scale vs. multi-resolution</summary>\n",
    "\n",
    "Note that, in the case of a rigid transformation such as a rotation, multi-scale optimization could already be achieved with a multi-scale representation obtained by blurring the images with consecutively larger Gaussian kernels (or repeatedly with the same kernel), without also reducing the image size itself. In non-rigid registration, the images are downsampled mainly for two reasons: a) to reduce the resolution of the image deformation field, and b) to reduce the computational cost at lower scales. In `deepali.spatial`, non-rigid transformations such as `DisplacementFieldTransform` and `FreeFormDeformation` have a `stride` parameter, which can be used to reduce the size of the non-rigid transformation without changing the image sampling grid. Here, we use a standard multi-resolution image pyramid with downsampling.\n",
    "</details>\n",
    "</br>"
   ]
  },
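  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a Gaussian image pyramid concrete, the following minimal sketch builds one with plain PyTorch: each level is obtained by blurring the previous level with a separable Gaussian kernel and keeping every other pixel. The helper functions and the default `sigma=1.0` are assumptions for illustration; this is not necessarily how `Image.pyramid()` or `Grid.pyramid()` are implemented in `deepali`, which additionally keep track of the sampling grid at each level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "def gaussian_kernel1d(sigma: float) -> Tensor:\n",
    "    # Normalized 1D Gaussian kernel truncated at 3 standard deviations\n",
    "    radius = max(1, math.ceil(3 * sigma))\n",
    "    x = torch.arange(-radius, radius + 1, dtype=torch.float32)\n",
    "    kernel = torch.exp(-0.5 * (x / sigma) ** 2)\n",
    "    return kernel / kernel.sum()\n",
    "\n",
    "\n",
    "def gaussian_blur(images: Tensor, sigma: float) -> Tensor:\n",
    "    # Separable Gaussian blur of a batch of shape (N, C, H, W), applied to each channel separately\n",
    "    kernel = gaussian_kernel1d(sigma).to(images)\n",
    "    channels = images.shape[1]\n",
    "    r = kernel.numel() // 2\n",
    "    kx = kernel.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)\n",
    "    ky = kernel.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)\n",
    "    out = F.conv2d(F.pad(images, (r, r, 0, 0), mode='replicate'), kx, groups=channels)\n",
    "    out = F.conv2d(F.pad(out, (0, 0, r, r), mode='replicate'), ky, groups=channels)\n",
    "    return out\n",
    "\n",
    "\n",
    "def gaussian_pyramid(images: Tensor, levels: int, sigma: float = 1.0) -> dict[int, Tensor]:\n",
    "    # Level 0 is the input; each coarser level is blurred, then subsampled by a factor of 2\n",
    "    pyramid = {0: images}\n",
    "    for level in range(1, levels):\n",
    "        pyramid[level] = gaussian_blur(pyramid[level - 1], sigma)[..., ::2, ::2]\n",
    "    return pyramid"
   ]
  },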
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepali.data import Image\n",
    "\n",
    "levels = 3\n",
    "\n",
    "target_pyramid = Image(target, grid).pyramid(levels)\n",
    "source_pyramid = Image(source, grid).pyramid(levels)\n",
    "\n",
    "fig, axes = plt.subplots(2, levels, figsize=(4 * levels, 8), tight_layout=True)\n",
    "\n",
    "for (level, tgt), src in zip(target_pyramid.items(), source_pyramid.values()):\n",
    "    imshow(tgt, f\"target, level {level}\", ax=axes[0, level])\n",
    "    imshow(src, f\"source, level {level}\", ax=axes[1, level])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With these image pyramids, we can implement a sequential multi-resolution registration as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ImagePyramid = dict[int, Image]\n",
    "LossFunction = Callable[[Tensor, Tensor, spatial.SpatialTransform], Tensor | dict[str, Tensor]]\n",
    "TransformCls = str | Type[spatial.SpatialTransform]\n",
    "TransformArg = TransformCls | Tuple[TransformCls, dict[str, Any]]\n",
    "OptimizerCls = str | Type[optim.Optimizer]\n",
    "OptimizerArg = OptimizerCls | Tuple[OptimizerCls, dict[str, Any]]\n",
    "\n",
    "\n",
    "def image_pyramid(\n",
    "    image: Tensor | Image | ImagePyramid,\n",
    "    levels: int,\n",
    "    grid: Optional[Grid] = None,\n",
    "    device: Optional[torch.device] = None,\n",
    ") -> ImagePyramid:\n",
    "    r\"\"\"Construct image pyramid from image tensor, or convert a given image pyramid.\"\"\"\n",
    "    if isinstance(image, dict):\n",
    "        pyramid = {}\n",
    "        for level, im in image.items():\n",
    "            if type(level) is not int:\n",
    "                raise TypeError(\"Image pyramid key values must be int\")\n",
    "            if level >= levels:\n",
    "                break\n",
    "            if type(im) is Tensor:\n",
    "                im = Image(im, grid)\n",
    "            if not isinstance(im, Image):\n",
    "                raise TypeError(\"Image pyramid key values must be deepali.data.Image or torch.Tensor\")\n",
    "            im = cast(Image, im.float().to(device))\n",
    "            pyramid[level] = im\n",
    "        if len(pyramid) < levels:\n",
    "            raise ValueError(f\"Expected image pyramid with {levels} levels, but only got {len(pyramid)} levels\")\n",
    "    else:\n",
    "        if not isinstance(image, Image):\n",
    "            image = Image(image, grid)\n",
    "        image = cast(Image, image.float().to(device))\n",
    "        pyramid = image.pyramid(levels)\n",
    "    return pyramid\n",
    "\n",
    "\n",
    "def init_transform(transform: TransformArg, grid: Grid, device: Optional[torch.device] = None) -> spatial.SpatialTransform:\n",
    "    r\"\"\"Auxiliary function to create spatial transform.\"\"\"\n",
    "    if isinstance(transform, tuple):\n",
    "        cls, args = transform\n",
    "    else:\n",
    "        cls = transform\n",
    "        args = {}\n",
    "    if isinstance(cls, str):\n",
    "        spatial_transform = spatial.new_spatial_transform(cls, grid, **args)\n",
    "    else:\n",
    "        spatial_transform = cls(grid, **args)\n",
    "    return spatial_transform.to(device).train()\n",
    "\n",
    "\n",
    "def init_optimizer(optimizer: OptimizerArg, transform: spatial.SpatialTransform) -> optim.Optimizer:\n",
    "    r\"\"\"Auxiliary function to initialize optimizer.\"\"\"\n",
    "    if isinstance(optimizer, tuple):\n",
    "        cls, args = optimizer\n",
    "    else:\n",
    "        cls = optimizer\n",
    "        args = {}\n",
    "    if isinstance(cls, str):\n",
    "        cls = getattr(optim, cls)\n",
    "    if not issubclass(cls, optim.Optimizer):\n",
    "        raise TypeError(\"'optimizer' must be a torch.optim.Optimizer\")\n",
    "    return cls(transform.parameters(), **args)\n",
    "\n",
    "\n",
    "def multi_resolution_registration(\n",
    "    target: Tensor | Image | ImagePyramid,\n",
    "    source: Tensor | Image | ImagePyramid,\n",
    "    loss_fn: LossFunction,\n",
    "    transform: TransformArg,\n",
    "    optimizer: OptimizerArg,\n",
    "    iterations: int | list[int] = 100,\n",
    "    levels: int = 3,\n",
    "    device: Optional[str | int | torch.device] = None,\n",
    ") -> spatial.SpatialTransform:\n",
    "    r\"\"\"Multi-resolution pairwise image registration.\"\"\"\n",
    "    if device is None:\n",
    "        if isinstance(target, dict):\n",
    "            device = next(iter(target.values())).device\n",
    "        else:\n",
    "            device = target.device\n",
    "    device = torch.device(f\"cuda:{device}\" if type(device) is int else device)\n",
    "    target = image_pyramid(target, levels=levels, device=device)\n",
    "    levels = len(target)\n",
    "    source = image_pyramid(source, levels=levels, device=device)\n",
    "    model = init_transform(transform, target[levels - 1].grid(), device=device)\n",
    "    bar_format = \"{l_bar}{bar}{rate_fmt}{postfix}\"\n",
    "    if isinstance(iterations, int):\n",
    "        iterations = [iterations]\n",
    "    iterations = list(iterations)\n",
    "    iterations += [iterations[-1]] * (levels - len(iterations))\n",
    "    for level, steps in zip(reversed(range(levels)), iterations):\n",
    "        model.grid_(target[level].grid())\n",
    "        target_batch = target[level].batch().tensor()\n",
    "        source_batch = source[level].batch().tensor()\n",
    "        transformer = spatial.ImageTransformer(model)\n",
    "        # New optimizer at each level; named opt to avoid shadowing the torch.optim module\n",
    "        opt = init_optimizer(optimizer, model)\n",
    "        for _ in (pbar := tqdm(range(steps), bar_format=bar_format)):\n",
    "            warped_batch: Tensor = transformer(source_batch)\n",
    "            loss = loss_fn(warped_batch, target_batch, model)\n",
    "            if isinstance(loss, Tensor):\n",
    "                loss = dict(loss=loss)\n",
    "            pbar.set_description(f\"Level {level}\")\n",
    "            pbar.set_postfix({k: v.item() for k, v in loss.items()})\n",
    "            opt.zero_grad()\n",
    "            loss[\"loss\"].backward()\n",
    "            opt.step()\n",
    "    return model.eval()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see what we now obtain for the rigid registration of our two digits using this registration function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = multi_resolution_registration(\n",
    "    target=target,\n",
    "    source=source,\n",
    "    transform=spatial.EulerRotation,\n",
    "    optimizer=(optim.Adam, {\"lr\": 1e-2}),\n",
    "    loss_fn=lambda a, b, _: sim_loss(a, b),\n",
    "    device=device,\n",
    ")\n",
    "transform = transform.cpu()\n",
    "\n",
    "with torch.inference_mode():\n",
    "    transformer = spatial.ImageTransformer(transform)\n",
    "    warped: Tensor = transformer(source.unsqueeze(0))[0]  # (N, C, Y, X) batch in, (C, Y, X) out\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n",
    "\n",
    "imshow(target, \"target\", ax=axes[0])\n",
    "imshow(warped, \"warped\", ax=axes[1])\n",
    "imshow(source, \"source\", ax=axes[2])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! The multi-resolution gradient descent converged to a solution similar to the one found by the previous exhaustive grid search."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Non-rigid registration\n",
    "\n",
    "Multi-resolution optimization becomes even more important when the spatial transform is non-rigid. Apart from linearly interpolated dense vector fields, a free-form deformation based on a cubic B-spline parameterization is commonly employed in medical image registration. This non-rigid transformation model has the advantage that first and second order derivatives can be computed exactly. Using our previously defined `multi_resolution_registration()` function, we can alternatively optimize a non-rigid deformation such as a `spatial.FreeFormDeformation` (FFD)."
   ]
  },
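  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The exact derivatives of a cubic B-spline model follow from it being a weighted sum of piecewise cubic polynomials whose derivatives are known in closed form. The following minimal 1D sketch with plain PyTorch evaluates a uniform cubic B-spline and its first and second derivatives analytically. The helpers are hypothetical and only illustrate the principle; they are not the `deepali` implementation of `FreeFormDeformation` or `BSplineTransform`, which operate on multi-dimensional control point grids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "def cubic_bspline_weights(t: Tensor, derivative: int = 0) -> Tensor:\n",
    "    # t: fractional position in [0, 1) within a control point cell\n",
    "    # Returns the weights of the 4 neighboring control points, shape (..., 4)\n",
    "    if derivative == 0:\n",
    "        w = [(1 - t) ** 3, 3 * t**3 - 6 * t**2 + 4, -3 * t**3 + 3 * t**2 + 3 * t + 1, t**3]\n",
    "        return torch.stack(w, dim=-1) / 6\n",
    "    if derivative == 1:\n",
    "        w = [-((1 - t) ** 2), 3 * t**2 - 4 * t, -3 * t**2 + 2 * t + 1, t**2]\n",
    "        return torch.stack(w, dim=-1) / 2\n",
    "    if derivative == 2:\n",
    "        w = [1 - t, 3 * t - 2, -3 * t + 1, t]\n",
    "        return torch.stack(w, dim=-1)\n",
    "    raise ValueError('derivative must be 0, 1, or 2')\n",
    "\n",
    "\n",
    "def evaluate_bspline_1d(coeffs: Tensor, x: Tensor, stride: float, derivative: int = 0) -> Tensor:\n",
    "    # coeffs[k] is the control point value located at x = (k - 1) * stride\n",
    "    # x: sample positions with 0 <= x < (len(coeffs) - 3) * stride\n",
    "    u = x / stride\n",
    "    i = torch.floor(u).long()\n",
    "    t = u - i\n",
    "    w = cubic_bspline_weights(t, derivative)\n",
    "    idx = i.unsqueeze(-1) + torch.arange(4, device=x.device)\n",
    "    value = (coeffs[idx] * w).sum(dim=-1)\n",
    "    # Chain rule: each derivative with respect to x adds a factor 1 / stride\n",
    "    return value / stride**derivative"
   ]
  },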
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(\n",
    "    w_curvature: float = 0,\n",
    "    w_diffusion: float = 0,\n",
    "    w_bending: float = 0,\n",
    ") -> Callable[[Tensor, Tensor, spatial.SpatialTransform], dict[str, Tensor]]:\n",
    "    r\"\"\"Construct loss function for free-form deformation (FFD) based image registration.\n",
    "\n",
    "    Args:\n",
    "        w_curvature: Weight of curvature, i.e., sum of unmixed second order derivatives.\n",
    "            When the spatial transform is parameterized by velocities, the curvature of\n",
    "            the velocity vector field is computed.\n",
    "        w_diffusion: Weight of diffusion, i.e., sum of first order derivatives.\n",
    "        w_bending: Weight of bending energy, i.e., sum of second order derivatives.\n",
    "\n",
    "    Returns:\n",
    "        Loss function which takes as input a registered image pair, and the spatial transform\n",
    "        used to register the images. The loss function evaluates the alignment of the images\n",
    "        based on a similarity term and optional regularization terms (transform penalties).\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def loss(\n",
    "        warped: Tensor,\n",
    "        target: Tensor,\n",
    "        transform: spatial.SpatialTransform,\n",
    "    ) -> dict[str, Tensor]:\n",
    "        terms: dict[str, Tensor] = {}\n",
    "        # Similarity term\n",
    "        sim = sim_loss(warped, target)\n",
    "        terms[\"sim\"] = sim\n",
    "        loss = sim\n",
    "        # Regularization terms\n",
    "        # v_or_u: dense velocity or displacement vector field, respectively.\n",
    "        v_or_u = getattr(transform, \"v\", getattr(transform, \"u\", None))\n",
    "        assert v_or_u is not None\n",
    "        if w_curvature > 0:\n",
    "            curvature = L.curvature_loss(v_or_u)\n",
    "            loss = curvature.mul(w_curvature).add(loss)\n",
    "            terms[\"curv\"] = curvature\n",
    "        if w_diffusion > 0:\n",
    "            diffusion = L.diffusion_loss(v_or_u)\n",
    "            loss = diffusion.mul(w_diffusion).add(loss)\n",
    "            terms[\"diff\"] = diffusion\n",
    "        if w_bending > 0:\n",
    "            if isinstance(transform, spatial.BSplineTransform):\n",
    "                params = transform.params\n",
    "                assert isinstance(params, Tensor)\n",
    "                bending = L.bspline_bending_loss(params)\n",
    "            else:\n",
    "                bending = L.bending_loss(v_or_u)\n",
    "            loss = bending.mul(w_bending).add(loss)\n",
    "            terms[\"be\"] = bending\n",
    "        return {\"loss\": loss, **terms}\n",
    "\n",
    "    return loss"
   ]
  },
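  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The regularization terms in `loss_fn()` penalize different derivatives of the vector field. As a rough illustration, the following sketch implements simple finite-difference versions of a diffusion penalty (squared first order derivatives) and a bending energy (squared second order derivatives, including the mixed derivative) for a 2D vector field of shape `(N, 2, H, W)` with unit grid spacing. The helpers `diffusion_penalty()` and `bending_penalty()` are hypothetical. deepali's `L.diffusion_loss()` and `L.bending_loss()` may differ in how derivatives are approximated, how boundaries are handled, and how the result is normalized. The usage example checks the expected behavior: a translation is penalized by neither term, while an affine displacement is penalized by the diffusion term only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "def diffusion_penalty(u: Tensor) -> Tensor:\n",
    "    # Hypothetical helper: mean squared forward differences along x (last dim) and y (second to last dim).\n",
    "    du_dx = u[..., :, 1:] - u[..., :, :-1]\n",
    "    du_dy = u[..., 1:, :] - u[..., :-1, :]\n",
    "    return du_dx.pow(2).mean() + du_dy.pow(2).mean()\n",
    "\n",
    "\n",
    "def bending_penalty(u: Tensor) -> Tensor:\n",
    "    # Hypothetical helper: mean squared u_xx and u_yy plus twice the mean squared mixed derivative u_xy.\n",
    "    d2u_dx2 = u[..., :, 2:] - 2 * u[..., :, 1:-1] + u[..., :, :-2]\n",
    "    d2u_dy2 = u[..., 2:, :] - 2 * u[..., 1:-1, :] + u[..., :-2, :]\n",
    "    d2u_dxdy = u[..., 1:, 1:] - u[..., 1:, :-1] - u[..., :-1, 1:] + u[..., :-1, :-1]\n",
    "    return d2u_dx2.pow(2).mean() + d2u_dy2.pow(2).mean() + 2 * d2u_dxdy.pow(2).mean()\n",
    "\n",
    "\n",
    "# A translation is penalized by neither term, an affine displacement only by the diffusion term.\n",
    "h, w = 32, 32\n",
    "yy, xx = torch.meshgrid(torch.arange(h, dtype=torch.float), torch.arange(w, dtype=torch.float), indexing=\"ij\")\n",
    "u_translation = torch.ones(1, 2, h, w)\n",
    "u_affine = torch.stack([0.1 * xx + 0.05 * yy, -0.05 * xx], dim=0).unsqueeze(0)\n",
    "print(diffusion_penalty(u_translation).item(), bending_penalty(u_translation).item())\n",
    "print(diffusion_penalty(u_affine).item(), bending_penalty(u_affine).item())"
   ]
  },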
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = multi_resolution_registration(\n",
    "    target=target_pyramid,\n",
    "    source=source_pyramid,\n",
    "    transform=(\"FFD\", {\"stride\": 2}),\n",
    "    optimizer=(\"Adam\", {\"lr\": 1e-2}),\n",
    "    loss_fn=loss_fn(w_bending=1e-5),\n",
    "    device=device,\n",
    ")\n",
    "transform = transform.cpu()\n",
    "\n",
    "with torch.inference_mode():\n",
    "    transformer = spatial.ImageTransformer(transform)\n",
    "    warped: Tensor = transformer(source)\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n",
    "\n",
    "imshow(target, \"target\", ax=axes[0])\n",
    "imshow(warped, \"warped\", ax=axes[1])\n",
    "imshow(source, \"source\", ax=axes[2])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to the parameterization of the spatial transform (i.e., a vector field given by a cubic B-spline with control points at every second image grid point), we added a loss term (`bspline_bending_loss()`) which penalizes bending of the spline. Setting the weight of this term, `w_bending`, to higher values makes the spatial transform stiffer and thus allows for less severe deformations. To assess the quality of the computed deformation, we can visualize how it deforms the image grid. Because the MNIST images are small, we generate a higher resolution grid image for this visualization, which we then deform using an `ImageTransformer` whose output (`target`) and input (`source`) grids both match this higher resolution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepali.core import functional as U\n",
    "\n",
    "\n",
    "grid_highres = grid.resize(512)\n",
    "grid_image = U.grid_image(grid_highres, num=1, stride=8, inverted=True)\n",
    "grid_transformer = spatial.ImageTransformer(transform, grid_highres, padding=\"zeros\")\n",
    "\n",
    "with torch.inference_mode():\n",
    "    warped_grid: Tensor = grid_transformer(grid_image)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(8, 4), tight_layout=True)\n",
    "\n",
    "imshow(grid_image, \"source grid\", ax=axes[0])\n",
    "imshow(warped_grid, \"warped grid\", ax=axes[1])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although the deformed source image matches the target image well, the warped grid reveals that the deformation itself is not well behaved. This can be remedied by increasing the weight of the bending loss (`w_bending`), here from `1e-5` to `1e-3`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = multi_resolution_registration(\n",
    "    target=target_pyramid,\n",
    "    source=source_pyramid,\n",
    "    transform=(spatial.FreeFormDeformation, {\"stride\": 2}),\n",
    "    optimizer=(optim.Adam, {\"lr\": 1e-2}),\n",
    "    loss_fn=loss_fn(w_bending=1e-3),\n",
    "    device=device,\n",
    ")\n",
    "transform = transform.cpu()\n",
    "\n",
    "with torch.inference_mode():\n",
    "    image_transformer = spatial.ImageTransformer(transform)\n",
    "    grid_transformer = spatial.ImageTransformer(transform, grid_highres, padding=\"zeros\")\n",
    "    warped_grid: Tensor = grid_transformer(grid_image)\n",
    "    warped: Tensor = image_transformer(source)\n",
    "\n",
    "fig, axes = plt.subplots(1, 4, figsize=(16, 4), tight_layout=True)\n",
    "\n",
    "imshow(target, \"target\", ax=axes[0])\n",
    "imshow(warped, \"warped\", ax=axes[1])\n",
    "imshow(source, \"source\", ax=axes[2])\n",
    "imshow(warped_grid, \"deformation\", ax=axes[3])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Diffeomorphic registration\n",
    "\n",
    "In the previous non-rigid registration example, we used a free-form deformation model as our spatial transformation. With a good choice of the control point spacing (`stride`) and bending regularization weight (`w_bending`), we ended up with a smooth and invertible spatial transform. Sometimes, however, it is desirable that invertibility, with both the forward and the backward map being differentiable, is built into the chosen transformation model itself. Such a diffeomorphic coordinate map is typically parameterized either by a time-varying velocity vector field or by a stationary velocity vector field (SVF). The latter is more common in recent registration literature because of its computational efficiency. The dense displacement field generated by an SVF is obtained by computing the exponential map of the velocity field, for example, with the scaling and squaring algorithm.\n",
    "\n",
    "The `spatial.StationaryVelocityFieldTransform` (SVF) and `spatial.StationaryVelocityFreeFormDeformation` (SVFFD) are diffeomorphic transformation models based on such a stationary velocity field. We can use the `SpatialTransform.inverse()` method to obtain the inverse deformation. With this, we can define the following auxiliary function to plot the result of a diffeomorphic registration, where the source image is deformed by the computed forward transform and the target image by its inverse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def invertible_registration_figure(\n",
    "    target: Tensor,\n",
    "    source: Tensor,\n",
    "    transform: spatial.SpatialTransform,\n",
    ") -> Figure:\n",
    "    r\"\"\"Create figure visualizing result of diffeomorphic registration.\n",
    "\n",
    "    Args:\n",
    "        target: Fixed target image.\n",
    "        source: Moving source image.\n",
    "        transform: Invertible spatial transform, i.e., must implement ``SpatialTransform.inverse()``.\n",
    "\n",
    "    Returns:\n",
    "        Instance of ``matplotlib.figure.Figure``.\n",
    "\n",
    "    \"\"\"\n",
    "    device = transform.device\n",
    "\n",
    "    highres_grid = transform.grid().resize(512)\n",
    "    grid_image = U.grid_image(highres_grid, num=1, stride=8, inverted=True, device=device)\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        inverse = transform.inverse()\n",
    "\n",
    "        source_transformer = spatial.ImageTransformer(transform)\n",
    "        target_transformer = spatial.ImageTransformer(inverse)\n",
    "\n",
    "        source_grid_transformer = spatial.ImageTransformer(transform, highres_grid, padding=\"zeros\")\n",
    "        target_grid_transformer = spatial.ImageTransformer(inverse, highres_grid, padding=\"zeros\")\n",
    "\n",
    "        warped_source: Tensor = source_transformer(source.to(device))\n",
    "        warped_target: Tensor = target_transformer(target.to(device))\n",
    "\n",
    "        warped_source_grid: Tensor = source_grid_transformer(grid_image)\n",
    "        warped_target_grid: Tensor = target_grid_transformer(grid_image)\n",
    "\n",
    "    fig, axes = plt.subplots(2, 3, figsize=(12, 8), tight_layout=True)\n",
    "\n",
    "    imshow(target, \"target\", ax=axes[0, 0])\n",
    "    imshow(warped_source, \"warped source\", ax=axes[0, 1])\n",
    "    imshow(warped_source_grid, \"forward deformation\", ax=axes[0, 2])\n",
    "\n",
    "    imshow(source, \"source\", ax=axes[1, 0])\n",
    "    imshow(warped_target, \"warped target\", ax=axes[1, 1])\n",
    "    imshow(warped_target_grid, \"inverse deformation\", ax=axes[1, 2])\n",
    "\n",
    "    return fig"
   ]
  },
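  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaling and squaring algorithm mentioned above can be sketched in a few lines: the velocity field is divided by $2^K$, and the resulting small displacement is then composed with itself $K$ times. The following minimal 2D version uses `torch.nn.functional.grid_sample` and assumes displacements given in normalized `[-1, 1]` coordinates with `align_corners=True`. It is an illustration rather than deepali's implementation, and the helpers `identity_grid()`, `resample_displaced()`, and `scaling_and_squaring()` are hypothetical. Since the inverse of `exp(v)` is `exp(-v)`, composing the two computed maps should approximately yield the identity, which the usage example checks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "def identity_grid(h: int, w: int) -> Tensor:\n",
    "    # Hypothetical helper: normalized coordinates in [-1, 1] of shape (1, H, W, 2), last dim ordered (x, y).\n",
    "    y, x = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing=\"ij\")\n",
    "    return torch.stack([x, y], dim=-1).unsqueeze(0)\n",
    "\n",
    "\n",
    "def resample_displaced(u: Tensor, disp: Tensor) -> Tensor:\n",
    "    # Sample vector field u of shape (N, 2, H, W) at x + disp(x), i.e., compute u o (id + disp).\n",
    "    sample_at = identity_grid(*u.shape[-2:]).to(u) + disp.permute(0, 2, 3, 1)\n",
    "    return F.grid_sample(u, sample_at, mode=\"bilinear\", padding_mode=\"border\", align_corners=True)\n",
    "\n",
    "\n",
    "def scaling_and_squaring(v: Tensor, steps: int = 6) -> Tensor:\n",
    "    # Displacement field of exp(v) for a stationary velocity field v of shape (N, 2, H, W).\n",
    "    u = v / 2**steps  # scaling: small deformation, approximately exp(v / 2^K) - id\n",
    "    for _ in range(steps):\n",
    "        u = u + resample_displaced(u, u)  # squaring: (id + u) o (id + u) - id\n",
    "    return u\n",
    "\n",
    "\n",
    "# Inverse consistency: exp(-v) o exp(v) should be close to the identity map.\n",
    "id_grid = identity_grid(64, 64)\n",
    "x_norm, y_norm = id_grid[..., 0], id_grid[..., 1]\n",
    "v_demo = 0.05 * torch.stack([torch.sin(math.pi * x_norm), torch.sin(math.pi * y_norm)], dim=1)\n",
    "u_fwd = scaling_and_squaring(v_demo)\n",
    "u_inv = scaling_and_squaring(-v_demo)\n",
    "residual = u_fwd + resample_displaced(u_inv, u_fwd)\n",
    "print(residual.abs().max().item())"
   ]
  },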
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the previously defined `multi_resolution_registration()` helper to optimize a diffeomorphic transformation. Here, we select the SVFFD with a control point at every other image grid point. Note that `stride` is given in number of image grid points. Hence, the resolution of the control point grid is determined by the resolution of the target image at the current level of the multi-resolution optimization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = multi_resolution_registration(\n",
    "    target=target_pyramid,\n",
    "    source=source_pyramid,\n",
    "    transform=(spatial.StationaryVelocityFreeFormDeformation, {\"stride\": 2}),\n",
    "    optimizer=(optim.Adam, {\"lr\": 1e-2}),\n",
    "    loss_fn=loss_fn(w_bending=1e-3),\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The visualization shows good alignment both between the deformed source image and the target image, and between the target image deformed by the inverse transform and the source image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = invertible_registration_figure(target, source, transform)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Circle to C deformation\n",
    "\n",
    "A classic example for demonstrating diffeomorphic registration is the \"circle to C\" problem, where an image of a circle is deformed to align it with an image of the letter \"C\". This example was used in the [DARTEL paper](https://doi.org/10.1016/j.neuroimage.2007.07.007) by John Ashburner, one of the seminal works which first introduced this parameterization for use in medical image registration. We can generate a similar example, though with a more open and thus more challenging target \"C\" shape, using the functions `cshape_image()` and `circle_image()` of `deepali.core.functional`. Gaussian blurring (`sigma`) is applied to smooth the edges of these synthetic binary images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
    "\n",
    "circle_c_grid = Grid((128, 128))\n",
    "target_cshape = U.cshape_image(circle_c_grid, radius=50, sigma=2, dtype=torch.float)\n",
    "source_circle = U.circle_image(circle_c_grid, radius=50, sigma=2, dtype=torch.float)\n",
    "\n",
    "imshow(target_cshape, \"c-shape\", ax=axes[0])\n",
    "imshow(source_circle, \"circle\", ax=axes[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "circle_to_c_transform = multi_resolution_registration(\n",
    "    target=target_cshape[0],\n",
    "    source=source_circle[0],\n",
    "    loss_fn=loss_fn(w_diffusion=1e-3),\n",
    "    transform=spatial.StationaryVelocityFreeFormDeformation,\n",
    "    optimizer=(optim.Adam, {\"lr\": 1e-2}),\n",
    "    iterations=[200, 200, 300, 100],\n",
    "    levels=4,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = invertible_registration_figure(target_cshape, source_circle, circle_to_c_transform)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepali",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
